agentmux_srv\backend\wshutil/
proxy.rs1#![allow(dead_code)]
2use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11use serde::{Deserialize, Serialize};
12use tokio::sync::mpsc;
13use super::osc::{DEFAULT_INPUT_CH_SIZE, DEFAULT_OUTPUT_CH_SIZE};
14
15#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17pub struct RpcContext {
18 #[serde(default, skip_serializing_if = "String::is_empty", rename = "blockid")]
19 pub block_id: String,
20 #[serde(default, skip_serializing_if = "String::is_empty", rename = "tabid")]
21 pub tab_id: String,
22 #[serde(default, skip_serializing_if = "String::is_empty", rename = "conn")]
23 pub conn: String,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct RpcMessage {
29 #[serde(default, skip_serializing_if = "String::is_empty")]
30 pub command: String,
31 #[serde(default, skip_serializing_if = "String::is_empty", rename = "reqid")]
32 pub req_id: String,
33 #[serde(default, skip_serializing_if = "String::is_empty", rename = "resid")]
34 pub res_id: String,
35 #[serde(default, skip_serializing_if = "Option::is_none")]
36 pub data: Option<serde_json::Value>,
37 #[serde(default, skip_serializing_if = "Option::is_none")]
38 pub error: Option<String>,
39 #[serde(default)]
40 pub cont: bool,
41 #[serde(default, skip_serializing_if = "Option::is_none")]
42 pub cancel: Option<bool>,
43 #[serde(default, skip_serializing_if = "Option::is_none", rename = "route")]
44 pub route: Option<String>,
45 #[serde(default, skip_serializing_if = "Option::is_none", rename = "source")]
46 pub source: Option<String>,
47 #[serde(default, skip_serializing_if = "Option::is_none", rename = "authtoken")]
48 pub auth_token: Option<String>,
49 #[serde(default, skip_serializing_if = "Option::is_none", rename = "timeout")]
50 pub timeout: Option<u64>,
51}
52
53impl RpcMessage {
54 pub fn is_request(&self) -> bool {
56 !self.command.is_empty() && !self.req_id.is_empty()
57 }
58
59 pub fn is_response(&self) -> bool {
61 !self.res_id.is_empty()
62 }
63
64 pub fn is_error(&self) -> bool {
66 self.error.is_some()
67 }
68
69 pub fn is_final(&self) -> bool {
71 !self.cont
72 }
73}
74
75pub struct WshRpcProxy {
78 rpc_context: Arc<Mutex<Option<RpcContext>>>,
79 auth_token: Arc<Mutex<String>>,
80 pub to_remote: mpsc::Sender<Vec<u8>>,
81 pub from_remote: mpsc::Receiver<Vec<u8>>,
82 to_remote_rx: Option<mpsc::Receiver<Vec<u8>>>,
83 from_remote_tx: mpsc::Sender<Vec<u8>>,
84}
85
86impl WshRpcProxy {
87 pub fn new() -> Self {
88 let (to_remote_tx, to_remote_rx) = mpsc::channel(DEFAULT_INPUT_CH_SIZE);
89 let (from_remote_tx, from_remote_rx) = mpsc::channel(DEFAULT_OUTPUT_CH_SIZE);
90 Self {
91 rpc_context: Arc::new(Mutex::new(None)),
92 auth_token: Arc::new(Mutex::new(String::new())),
93 to_remote: to_remote_tx,
94 from_remote: from_remote_rx,
95 to_remote_rx: Some(to_remote_rx),
96 from_remote_tx,
97 }
98 }
99
100 pub fn set_rpc_context(&self, ctx: RpcContext) {
101 *self.rpc_context.lock().unwrap() = Some(ctx);
102 }
103
104 pub fn get_rpc_context(&self) -> Option<RpcContext> {
105 self.rpc_context.lock().unwrap().clone()
106 }
107
108 pub fn set_auth_token(&self, token: &str) {
109 *self.auth_token.lock().unwrap() = token.to_string();
110 }
111
112 pub fn get_auth_token(&self) -> String {
113 self.auth_token.lock().unwrap().clone()
114 }
115
116 pub fn take_to_remote_rx(&mut self) -> Option<mpsc::Receiver<Vec<u8>>> {
118 self.to_remote_rx.take()
119 }
120
121 pub fn from_remote_sender(&self) -> mpsc::Sender<Vec<u8>> {
123 self.from_remote_tx.clone()
124 }
125
126 pub async fn send_to_remote(&self, msg: Vec<u8>) -> Result<(), String> {
128 self.to_remote
129 .send(msg)
130 .await
131 .map_err(|e| format!("failed to send to remote: {}", e))
132 }
133
134 pub async fn send_rpc_message(&self, msg: &RpcMessage) -> Result<(), String> {
136 let json = serde_json::to_vec(msg).map_err(|e| format!("json encode: {}", e))?;
137 self.send_to_remote(json).await
138 }
139
140 pub async fn send_response_error(&self, req_id: &str, err_msg: &str) -> Result<(), String> {
142 if req_id.is_empty() {
143 return Ok(());
144 }
145 let msg = RpcMessage {
146 res_id: req_id.to_string(),
147 error: Some(err_msg.to_string()),
148 ..Default::default()
149 };
150 self.send_rpc_message(&msg).await
151 }
152}
153
154impl Default for RpcMessage {
155 fn default() -> Self {
156 Self {
157 command: String::new(),
158 req_id: String::new(),
159 res_id: String::new(),
160 data: None,
161 error: None,
162 cont: false,
163 cancel: None,
164 route: None,
165 source: None,
166 auth_token: None,
167 timeout: None,
168 }
169 }
170}
171
172pub struct WshMultiProxy {
175 proxies: Arc<Mutex<HashMap<String, mpsc::Sender<Vec<u8>>>>>,
176}
177
178impl WshMultiProxy {
179 pub fn new() -> Self {
180 Self {
181 proxies: Arc::new(Mutex::new(HashMap::new())),
182 }
183 }
184
185 pub fn add_proxy(&self, name: &str, sender: mpsc::Sender<Vec<u8>>) {
187 self.proxies.lock().unwrap().insert(name.to_string(), sender);
188 }
189
190 pub fn remove_proxy(&self, name: &str) {
192 self.proxies.lock().unwrap().remove(name);
193 }
194
195 pub async fn broadcast(&self, msg: Vec<u8>) {
197 let senders: Vec<mpsc::Sender<Vec<u8>>> = {
198 let proxies = self.proxies.lock().unwrap();
199 proxies.values().cloned().collect()
200 };
201
202 for sender in senders {
203 let msg_clone = msg.clone();
204 let _ = sender.send(msg_clone).await;
205 }
206 }
207
208 pub async fn broadcast_rpc_message(&self, msg: &RpcMessage) -> Result<(), String> {
210 let json = serde_json::to_vec(msg).map_err(|e| format!("json encode: {}", e))?;
211 self.broadcast(json).await;
212 Ok(())
213 }
214
215 pub fn proxy_count(&self) -> usize {
217 self.proxies.lock().unwrap().len()
218 }
219
220 pub fn proxy_names(&self) -> Vec<String> {
222 self.proxies.lock().unwrap().keys().cloned().collect()
223 }
224}
225
226impl Default for WshMultiProxy {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rpc_message_request() {
        let msg = RpcMessage {
            command: "test".to_string(),
            req_id: "abc123".to_string(),
            ..Default::default()
        };
        assert!(msg.is_request());
        assert!(!msg.is_response());
        assert!(!msg.is_error());
        assert!(msg.is_final());
    }

    #[test]
    fn test_rpc_message_response() {
        let msg = RpcMessage {
            res_id: "abc123".to_string(),
            data: Some(serde_json::json!({"result": "ok"})),
            ..Default::default()
        };
        assert!(!msg.is_request());
        assert!(msg.is_response());
        assert!(!msg.is_error());
    }

    #[test]
    fn test_rpc_message_error() {
        let msg = RpcMessage {
            res_id: "abc123".to_string(),
            error: Some("something failed".to_string()),
            ..Default::default()
        };
        assert!(msg.is_error());
    }

    #[test]
    fn test_rpc_message_serde() {
        let msg = RpcMessage {
            command: "getblock".to_string(),
            req_id: "req-1".to_string(),
            data: Some(serde_json::json!({"id": "block-1"})),
            route: Some("conn:local".to_string()),
            ..Default::default()
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: RpcMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.command, "getblock");
        assert_eq!(parsed.req_id, "req-1");
        assert_eq!(parsed.route.unwrap(), "conn:local");
    }

    #[test]
    fn test_rpc_message_omits_empty_fields() {
        let json = serde_json::to_value(RpcMessage::default()).unwrap();
        for key in [
            "command", "reqid", "resid", "data", "error", "cancel", "route", "source",
            "authtoken", "timeout",
        ] {
            assert!(json.get(key).is_none(), "unexpected key in output: {}", key);
        }
    }

    #[test]
    fn test_rpc_context_serde_names() {
        let ctx = RpcContext {
            block_id: "b1".to_string(),
            tab_id: "t1".to_string(),
            ..Default::default()
        };
        let json = serde_json::to_value(&ctx).unwrap();
        assert_eq!(json, serde_json::json!({"blockid": "b1", "tabid": "t1"}));

        let parsed: RpcContext = serde_json::from_value(json).unwrap();
        assert_eq!(parsed.block_id, "b1");
        assert_eq!(parsed.tab_id, "t1");
        assert!(parsed.conn.is_empty());
    }

    #[test]
    fn test_proxy_context_and_auth_token() {
        let proxy = WshRpcProxy::new();
        assert!(proxy.get_rpc_context().is_none());
        assert!(proxy.get_auth_token().is_empty());

        proxy.set_rpc_context(RpcContext {
            conn: "local".to_string(),
            ..Default::default()
        });
        proxy.set_auth_token("tok");

        assert_eq!(proxy.get_rpc_context().unwrap().conn, "local");
        assert_eq!(proxy.get_auth_token(), "tok");
    }

    #[test]
    fn test_proxy_take_to_remote_rx_only_once() {
        let mut proxy = WshRpcProxy::new();
        assert!(proxy.take_to_remote_rx().is_some());
        assert!(proxy.take_to_remote_rx().is_none());
    }

    #[tokio::test]
    async fn test_proxy_send_response_error() {
        let mut proxy = WshRpcProxy::new();
        let mut rx = proxy
            .take_to_remote_rx()
            .expect("receiver is available on first take");

        // An empty req_id has nobody to answer, so nothing must be queued.
        // This is observable because the first frame received below must be
        // the one for "req-9".
        proxy.send_response_error("", "ignored").await.unwrap();
        proxy.send_response_error("req-9", "boom").await.unwrap();

        let bytes = rx.recv().await.unwrap();
        let msg: RpcMessage = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(msg.res_id, "req-9");
        assert_eq!(msg.error.as_deref(), Some("boom"));
        assert!(msg.is_response());
        assert!(msg.is_error());
        assert!(!msg.is_request());
    }

    #[tokio::test]
    async fn test_proxy_from_remote_sender() {
        let mut proxy = WshRpcProxy::new();
        proxy
            .from_remote_sender()
            .send(b"in".to_vec())
            .await
            .unwrap();
        assert_eq!(proxy.from_remote.recv().await.unwrap(), b"in");
    }

    #[tokio::test]
    async fn test_multi_proxy_broadcast() {
        let multi = WshMultiProxy::new();
        let (tx1, mut rx1) = mpsc::channel(10);
        let (tx2, mut rx2) = mpsc::channel(10);

        multi.add_proxy("conn1", tx1);
        multi.add_proxy("conn2", tx2);
        assert_eq!(multi.proxy_count(), 2);

        multi.broadcast(b"hello".to_vec()).await;

        let msg1 = rx1.recv().await.unwrap();
        let msg2 = rx2.recv().await.unwrap();
        assert_eq!(msg1, b"hello");
        assert_eq!(msg2, b"hello");
    }

    #[tokio::test]
    async fn test_multi_proxy_broadcast_rpc_message() {
        let multi = WshMultiProxy::new();
        let (tx, mut rx) = mpsc::channel(4);
        multi.add_proxy("conn1", tx);

        let msg = RpcMessage {
            command: "ping".to_string(),
            ..Default::default()
        };
        multi.broadcast_rpc_message(&msg).await.unwrap();

        let bytes = rx.recv().await.unwrap();
        let parsed: RpcMessage = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(parsed.command, "ping");
    }

    #[test]
    fn test_multi_proxy_add_remove() {
        let multi = WshMultiProxy::new();
        let (tx, _rx) = mpsc::channel(10);

        multi.add_proxy("conn1", tx);
        assert_eq!(multi.proxy_count(), 1);
        assert_eq!(multi.proxy_names(), vec!["conn1".to_string()]);

        multi.remove_proxy("conn1");
        assert_eq!(multi.proxy_count(), 0);
    }
}